-
Notifications
You must be signed in to change notification settings - Fork 35
Use Threads.nthreads() * 2
in TSVI
#936
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Benchmark Report for Commit 6b9b332Computer Information
Benchmark Results
|
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #936 +/- ##
=======================================
Coverage 82.91% 82.91%
=======================================
Files 36 36
Lines 3962 3962
=======================================
Hits 3285 3285
Misses 677 677 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
Pull Request Test Coverage Report for Build 15216502921Details
💛 - Coveralls |
1 similar comment
Pull Request Test Coverage Report for Build 15216502921Details
💛 - Coveralls |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm happy with this as a hacky workaround, would only propose documenting why this is the way it is.
How would TuringLang/Turing.jl#2555 affect TSVI?
@@ -9,7 +9,9 @@ struct ThreadSafeVarInfo{V<:AbstractVarInfo,L} <: AbstractVarInfo | |||
logps::L | |||
end | |||
function ThreadSafeVarInfo(vi::AbstractVarInfo) | |||
return ThreadSafeVarInfo(vi, [Ref(zero(getlogp(vi))) for _ in 1:Threads.nthreads()]) | |||
return ThreadSafeVarInfo( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we add a comment here explaining the situation and maybe linking to this PR or the relevant issue? The * 2
would otherwise appear quite mysterious.
By making it opt-in, it would allow us to use In contrast, now TSVI is mandatory whenever |
Oh right, because you can specify an AbstractVarInfo when you make a LogDensityFunction. Got it. |
Optimistically, I might be able to get that done by the end of this week, but if I find it's too hard I'll come back to this! |
See #924 for background.
This PR adopts a similar approach to solution (1), i.e., using
Threads.maxthreadid()
. However, Mooncake can't differentiatemaxthreadid()
, so this is a hacky workaround to a hacky workaround, based on the observation thatmaxthreadid()
seems to be upper-bounded bynthreads() * 2
.Personally, I would be much more in favour of removing TSVI (or making it opt-in), and will probably do this in a separate PR once TuringLang/Turing.jl#2555 is solved.
But this is a quick enough fix that should ensure that TSVI continues to work on Julia 1.12 (even if for the wrong reasons). Whether we merge this will probably depend on how fast that Turing issue can be fixed -- if it's soon, then we can remove TSVI, if it's not soon, then this can plug the gap.
Note that if TSVI is opt-in, then we can just use
maxthreadid()
, thereby removing one layer of hackiness — because Mooncake doesn't work with multithreaded execution anyway, the scenarios where you'd want to opt into TSVI have no overlap with the scenarios where you'd want to use Mooncake.